package db_client

import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"net/netip"
	"slices"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/spf13/viper"
	pconstants "github.com/turbot/pipe-fittings/v2/constants"
	pqueryresult "github.com/turbot/pipe-fittings/v2/queryresult"
	"github.com/turbot/pipe-fittings/v2/utils"
	"github.com/turbot/steampipe/v2/pkg/constants"
	"github.com/turbot/steampipe/v2/pkg/db/db_common"
	"github.com/turbot/steampipe/v2/pkg/error_helpers"
	"github.com/turbot/steampipe/v2/pkg/query/queryresult"
	"github.com/turbot/steampipe/v2/pkg/statushooks"
	"golang.org/x/text/language"
	"golang.org/x/text/message"
)

// ExecuteSync implements Client.
// It acquires a session, runs the query in that session via ExecuteSyncInSession
// and closes the session again before returning.
func (c *DbClient) ExecuteSync(ctx context.Context, query string, args ...any) (*pqueryresult.SyncQueryResult, error) {
	acquired := c.AcquireSession(ctx)
	if acquired.Error != nil {
		return nil, acquired.Error
	}

	// the cancellation state of ctx has to be sampled when the deferred call actually runs,
	// not when the defer statement is evaluated - so wrap the Close call in a closure
	releaseSession := func() {
		acquired.Session.Close(error_helpers.IsContextCanceled(ctx))
	}
	defer releaseSession()

	return c.ExecuteSyncInSession(ctx, acquired.Session, query, args...)
}

// ExecuteSyncInSession implements Client.
// It runs the query in the supplied session and blocks until the row channel of the
// streamed result has been closed, collecting the rows into a SyncQueryResult.
//
// An empty query yields an empty result and no error. If one or more rows carry an error,
// the first of them is returned (wrapped) together with the collected rows. Rows received
// after ctx is done are drained from the channel but not collected.
func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, args ...any) (*pqueryresult.SyncQueryResult, error) {
	// nothing to run
	if query == "" {
		return &pqueryresult.SyncQueryResult{}, nil
	}

	streamed, err := c.ExecuteInSession(ctx, session, nil, query, args...)
	if err != nil {
		return nil, error_helpers.WrapError(err)
	}

	collected := &pqueryresult.SyncQueryResult{Cols: streamed.Cols}
	var firstRowErr error
	for row := range streamed.RowChan {
		// once the context is done keep draining the channel, but discard the rows
		select {
		case <-ctx.Done():
			continue
		default:
		}

		// remember only the first row error
		if row.Error != nil && firstRowErr == nil {
			firstRowErr = error_helpers.WrapError(row.Error)
		}
		collected.Rows = append(collected.Rows, row)
	}

	if c.shouldFetchTiming() {
		collected.Timing = streamed.Timing.GetTiming()
	}

	return collected, firstRowErr
}

// Execute implements Client.
// It acquires a session and starts the query in it; the session is closed by a completion
// callback handed to ExecuteInSession.
// NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
func (c *DbClient) Execute(ctx context.Context, query string, args ...any) (*queryresult.Result, error) {
	acquired := c.AcquireSession(ctx)
	if acquired.Error != nil {
		return nil, acquired.Error
	}

	// invoked by ExecuteInSession once the asynchronous execution has finished (or failed)
	releaseSession := func() {
		acquired.Session.Close(error_helpers.IsContextCanceled(ctx))
	}

	return c.ExecuteInSession(ctx, acquired.Session, releaseSession, query, args...)
}

// ExecuteInSession implements Client
// execute the query in the given Context using the provided DatabaseSession
// ExecuteInSession assumes no responsibility over the lifecycle of the DatabaseSession - that is the responsibility of the caller
// NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
//
// An empty query returns an empty Result. A nil session or a session without a connection
// is rejected with an error.
//
// onComplete (may be nil) is invoked exactly once: synchronously if starting the query
// fails, otherwise from the row-reading goroutine after all rows have been streamed.
func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, onComplete func(), query string, args ...any) (res *queryresult.Result, err error) {
	if query == "" {
		return queryresult.NewResult(nil), nil
	}

	// fail-safes
	if session == nil {
		return nil, fmt.Errorf("nil session passed to ExecuteInSession")
	}
	if session.Connection == nil {
		return nil, fmt.Errorf("nil database connection passed to ExecuteInSession")
	}
	startTime := time.Now()
	// get a context with a timeout for the query to execute within
	// we don't use the cancelFn from this timeout context, since usage will lead to 'pgx'
	// prematurely closing the database connection that this query executed in
	ctxExecute := c.getExecuteContext(ctx)

	// NOTE(review): tx is never assigned in this function, so the rollback branch below is
	// currently inert
	var tx *sql.Tx

	// error handling for every failure path below (err is the named result)
	defer func() {
		if err != nil {
			err = error_helpers.HandleQueryTimeoutError(err)
			// stop spinner in case of error
			statushooks.Done(ctxExecute)
			// error - rollback transaction if we have one
			if tx != nil {
				_ = tx.Rollback()
			}
			// in case of error call the onComplete callback
			if onComplete != nil {
				onComplete()
			}
		}
	}()

	// start query
	var rows pgx.Rows
	rows, err = c.startQueryWithRetries(ctxExecute, session, query, args...)
	if err != nil {
		return nil, err
	}

	colDefs, err := fieldDescriptionsToColumns(rows.FieldDescriptions(), session.Connection.Conn())
	if err != nil {
		// no reader goroutine will be started for these rows, so close them here -
		// otherwise the query stays open on the connection
		rows.Close()
		return nil, err
	}

	result := queryresult.NewResult(colDefs)

	// read the rows in a go routine
	go func() {
		// define a callback which fetches the timing information
		// this will be invoked after reading rows is complete but BEFORE closing the rows object (which closes the connection)
		timingCallback := func() {
			c.getQueryTiming(ctxExecute, startTime, session, result.Timing)
		}

		// read in the rows and stream to the query result object
		c.readRows(ctxExecute, rows, result, timingCallback)

		// call the completion callback - if one was provided
		if onComplete != nil {
			onComplete()
		}
	}()

	return result, nil
}

// getExecuteContext returns the context a query should execute in.
// The query timeout is read (in seconds) from the ArgDatabaseQueryTimeout setting; when it
// is zero ctx is returned unchanged, otherwise a child context whose deadline is that many
// seconds from now. The cancel function of the child context is deliberately dropped.
func (c *DbClient) getExecuteContext(ctx context.Context) context.Context {
	timeout := time.Duration(viper.GetInt(pconstants.ArgDatabaseQueryTimeout)) * time.Second
	if timeout != 0 {
		deadline := time.Now().Add(timeout)
		//nolint:golint,lostcancel //we don't use this cancel fn because, pgx prematurely cancels the PG connection when this cancel gets called in 'defer'
		deadlineCtx, _ := context.WithDeadline(ctx, deadline)
		return deadlineCtx
	}

	// a zero timeout means no deadline at all
	return ctx
}

// getQueryTiming gathers timing information for a query started at startTime and publishes
// it through resultChannel.SetTiming.
//
// Nothing is done (and nothing is published) when shouldFetchTiming reports false. Otherwise
// a TimingResult holding at least the elapsed duration is always published; the scan summary
// and, when shouldFetchVerboseTiming reports true, the per-scan metadata are added only if
// they can be loaded. Load failures are logged as warnings and not returned.
func (c *DbClient) getQueryTiming(ctx context.Context, startTime time.Time, session *db_common.DatabaseSession, resultChannel queryresult.TimingResultStream) {
	if !c.shouldFetchTiming() {
		return
	}

	timing := &queryresult.TimingResult{DurationMs: time.Since(startTime).Milliseconds()}

	// the queries issued below must not themselves trigger timing collection
	c.disableTiming.Store(true)
	defer func() {
		// regardless of the outcome: switch timing back on and publish what we have
		c.disableTiming.Store(false)
		resultChannel.SetTiming(timing)
	}()

	summary, summaryErr := c.loadTimingSummary(ctx, session)
	if summaryErr != nil {
		log.Printf("[WARN] getQueryTiming: failed to read scan metadata, err: %s", summaryErr)
		return
	}

	// the individual scan rows are only needed for verbose timing
	var scans []*queryresult.ScanMetadataRow
	if c.shouldFetchVerboseTiming() {
		var scansErr error
		if scans, scansErr = c.loadTimingMetadata(ctx, session); scansErr != nil {
			log.Printf("[WARN] getQueryTiming: failed to read scan metadata, err: %s", scansErr)
			return
		}
	}

	// fill in the summary and scan details
	timing.Initialise(summary, scans)
}

// loadTimingSummary reads the single-row scan metadata summary from the internal schema
// using the connection of the given session.
// Failing to run the query, or getting no row back, is reported as an error.
func (c *DbClient) loadTimingSummary(ctx context.Context, session *db_common.DatabaseSession) (*queryresult.QueryRowSummary, error) {
	summaryQuery := fmt.Sprintf(`select uncached_rows_fetched,
cached_rows_fetched,
hydrate_calls, 
scan_count,
connection_count from %s.%s `, constants.InternalSchema, constants.ForeignTableScanMetadataSummary)

	summary := &queryresult.QueryRowSummary{}
	err := db_common.ExecuteSystemClientCall(ctx, session.Connection.Conn(), func(ctx context.Context, tx pgx.Tx) error {
		rows, queryErr := tx.Query(ctx, summaryQuery)
		if queryErr != nil {
			return queryErr
		}

		// map the one expected row onto the summary struct by column name
		var collectErr error
		summary, collectErr = pgx.CollectOneRow(rows, pgx.RowToAddrOfStructByName[queryresult.QueryRowSummary])
		return collectErr
	})
	return summary, err
}

// loadTimingMetadata reads the per-scan metadata rows from the internal schema using the
// connection of the given session. The query orders the rows by duration_ms, descending.
func (c *DbClient) loadTimingMetadata(ctx context.Context, session *db_common.DatabaseSession) ([]*queryresult.ScanMetadataRow, error) {
	metadataQuery := fmt.Sprintf(`
select connection,
"table",
cache_hit, 
rows_fetched, 
hydrate_calls, 
start_time,
duration_ms,
columns,
"limit",
quals from %s.%s order by duration_ms desc`, constants.InternalSchema, constants.ForeignTableScanMetadata)

	var scans []*queryresult.ScanMetadataRow
	err := db_common.ExecuteSystemClientCall(ctx, session.Connection.Conn(), func(ctx context.Context, tx pgx.Tx) error {
		rows, queryErr := tx.Query(ctx, metadataQuery)
		if queryErr != nil {
			return queryErr
		}

		// map every row onto a ScanMetadataRow by column name
		var collectErr error
		scans, collectErr = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[queryresult.ScanMetadataRow])
		return collectErr
	})
	return scans, err
}

// startQuery runs conn.Query in a separate goroutine so the caller can stop waiting as soon
// as ctx is done, in case the client becomes unresponsive and does not respect context
// cancellation.
//
// The goroutine hands its outcome over through a buffered channel instead of writing to
// variables shared with the caller, so an abandoned query cannot race with the values being
// returned. If ctx is done first, nil rows and ctx.Err() are returned and an outcome that
// arrives afterwards is discarded.
func (c *DbClient) startQuery(ctx context.Context, conn *pgx.Conn, query string, args ...any) (pgx.Rows, error) {
	type queryOutcome struct {
		rows pgx.Rows
		err  error
	}

	// capacity 1: the send never blocks, so the goroutine always terminates even when
	// nobody is listening any more
	outcomeChan := make(chan queryOutcome, 1)
	go func() {
		// start asynchronous query
		rows, err := conn.Query(ctx, query, args...)
		outcomeChan <- queryOutcome{rows: rows, err: err}
	}()

	select {
	case outcome := <-outcomeChan:
		return outcome.rows, outcome.err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// readRows reads rows until they are exhausted, a row cannot be read, or ctx is done, and
// streams each converted row to result.
//
// On exit - in this order - the status indication is cleared, timingCallback is invoked,
// rows is closed, any error reported by rows.Err() is streamed to result, and result is
// closed.
func (c *DbClient) readRows(ctx context.Context, rows pgx.Rows, result *queryresult.Result, timingCallback func()) {
	// defer this, so that these get cleaned up even if there is an unforeseen error
	defer func() {
		// we are done fetching results. time for display. clear the status indication
		statushooks.Done(ctx)
		// call the timing callback BEFORE closing the rows
		timingCallback()
		// close the sql rows object
		rows.Close()
		if err := rows.Err(); err != nil {
			result.StreamError(err)
		}
		// close the channels in the result object
		result.Close()
	}()

	rowCount := 0
Loop:
	for rows.Next() {
		select {
		case <-ctx.Done():
			statushooks.SetStatus(ctx, "Cancelling query")
			break Loop
		default:
			rowResult, err := readRow(rows, result.Cols)
			if err != nil {
				// stop reading; the deferred cleanup streams whatever rows.Err() reports
				// NOTE(review): this relies on the failed read being recorded on rows - confirm
				break Loop
			}

			// TACTICAL
			// determine whether to stop the spinner as soon as we stream a row or to wait for completion
			if isStreamingOutput() {
				statushooks.Done(ctx)
			}

			result.StreamRow(rowResult)

			// count the row BEFORE reporting, so that the status message shows the number of
			// rows that have already been fetched (1 after the first row, not 0)
			// this will not show if the spinner is not active
			rowCount++
			statushooks.SetStatus(ctx, fmt.Sprintf("Loading results: %3s", humanizeRowCount(rowCount)))
		}
	}
}

// readRow fetches the values of the current row from rows and converts them with
// populateRow. A failure to fetch the values is returned wrapped.
func readRow(rows pgx.Rows, cols []*pqueryresult.ColumnDef) ([]interface{}, error) {
	rawValues, err := rows.Values()
	if err == nil {
		return populateRow(rawValues, cols)
	}
	return nil, error_helpers.WrapError(err)
}

// populateRow converts the raw column values of a single row into the values handed to the
// consumer. cols[i] must describe columnValues[i].
//
// nil values stay nil and values of unlisted types are passed through unchanged. For the
// data types below a value is converted when it has the expected Go type:
//   - _TEXT:    elements joined with ","; a NULL element contributes an empty string
//   - INET:     single-host prefixes (IPv4 /32, IPv6 /128) are shown without the mask,
//     every other prefix keeps its mask
//   - UUID:     a 16 byte array becomes a uuid.UUID
//   - TIME:     formatted as HH:MM:SS
//   - INTERVAL: formatted by formatPgInterval
//   - NUMERIC:  converted to float64 when possible
//
// The returned error is always nil.
func populateRow(columnValues []interface{}, cols []*pqueryresult.ColumnDef) ([]interface{}, error) {
	result := make([]interface{}, len(columnValues))
	for i, columnValue := range columnValues {
		// NULL columns stay nil
		if columnValue == nil {
			continue
		}

		// default: pass the value through as-is
		result[i] = columnValue

		switch cols[i].DataType {
		case "_TEXT":
			if arr, ok := columnValue.([]interface{}); ok {
				elements := make([]string, 0, len(arr))
				for _, e := range arr {
					switch v := e.(type) {
					case nil:
						// NULL array element - keep its position, do not panic
						elements = append(elements, "")
					case string:
						elements = append(elements, v)
					default:
						elements = append(elements, fmt.Sprint(v))
					}
				}
				result[i] = strings.Join(elements, ",")
			}
		case "INET":
			if prefix, ok := columnValue.(netip.Prefix); ok {
				// only drop the mask when it really denotes a single host; a plain suffix
				// trim would also mangle IPv6 networks such as 2001:db8::/32
				if prefix.IsSingleIP() {
					result[i] = prefix.Addr().String()
				} else {
					result[i] = prefix.String()
				}
			}
		case "UUID":
			if raw, ok := columnValue.([16]uint8); ok {
				if u, err := uuid.FromBytes(raw[:]); err == nil {
					result[i] = u
				}
			}
		case "TIME":
			if t, ok := columnValue.(pgtype.Time); ok {
				result[i] = time.UnixMicro(t.Microseconds).UTC().Format("15:04:05")
			}
		case "INTERVAL":
			if interval, ok := columnValue.(pgtype.Interval); ok {
				result[i] = formatPgInterval(interval)
			}
		case "NUMERIC":
			if numeric, ok := columnValue.(pgtype.Numeric); ok {
				if f, err := numeric.Float64Value(); err == nil {
					result[i] = f.Float64
				}
			}
		}
	}
	return result, nil
}

// formatPgInterval renders an interval as "<y> years <m> mons <d> days HH:MM:SS".
// Components that are zero are left out, negative components keep their sign, and the
// hour field is not wrapped at 24. Sub-second precision is dropped. An interval whose
// components are all zero is rendered as "00:00:00".
func formatPgInterval(interval pgtype.Interval) string {
	const microsPerSecond = 1_000_000

	parts := make([]string, 0, 4)
	if years := interval.Months / 12; years != 0 {
		parts = append(parts, fmt.Sprintf("%d %s", years, utils.Pluralize("year", int(years))))
	}
	if months := interval.Months % 12; months != 0 {
		parts = append(parts, fmt.Sprintf("%d %s", months, utils.Pluralize("mon", int(months))))
	}
	if interval.Days != 0 {
		parts = append(parts, fmt.Sprintf("%d %s", interval.Days, utils.Pluralize("day", int(interval.Days))))
	}

	// emit the clock part when it is non-zero, or when there is nothing else to show
	if micros := interval.Microseconds; micros != 0 || len(parts) == 0 {
		sign := ""
		if micros < 0 {
			sign = "-"
			micros = -micros
		}
		totalSeconds := micros / microsPerSecond
		parts = append(parts, fmt.Sprintf("%s%02d:%02d:%02d", sign, totalSeconds/3600, totalSeconds/60%60, totalSeconds%60))
	}

	return strings.Join(parts, " ")
}

func isStreamingOutput() bool {
	outputFormat := viper.GetString(pconstants.ArgOutput)

	return slices.Contains([]string{constants.OutputFormatCSV, constants.OutputFormatLine}, outputFormat)
}

// humanizeRowCount formats count as a decimal number using an English-language message
// printer.
func humanizeRowCount(count int) string {
	return message.NewPrinter(language.English).Sprintf("%d", count)
}
